# This script evaluates whether imputation using missForest or MICE gives the most similar performance to that which was identified from the complete-data analysis. 
# The imputed training datasets used in this script are from the "Imputation_method_comparison.txt" script. 
# The best algorithm identified from the initial model development stage was used for this - a preschool model using SVM using a linear kernel
# The performance of all models were evaluated on the same validation dataset used during the initial model development stage. Each imputed training dataset was standardised separately and the same properties applied to the test set. 
# Python version 3.6.8 is used

# Imports
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from time import time
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, StratifiedKFold, RandomizedSearchCV
from sklearn.model_selection import cross_val_score
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
from sklearn.metrics import balanced_accuracy_score

# Set working directory
# NOTE(review): "/../../" resolves to the filesystem root ("/"); this looks
# like a redacted/placeholder path — replace with the real data directory
# before running.
os.chdir("/../../")

############################################
### missForest imputed model development ###
############################################
# Import cleaned, unstandardised, missForest imputed preschool dataset  - data found in IOWBC_imputed_data.xlsx, sheet: "missForest complete data"
data = pd.read_csv("Imputed_missForest_preschool_complete_training_dataset_365IDs.csv", index_col=False)
# 365 Ids, 12 predictors

# Column 0 is presumably a Study_ID column and is skipped here (the MICE file
# below starts its predictors at column 0) — TODO confirm against the CSV.
X_train = data.iloc[:,1:13]  # columns 1-12: the 12 predictors
y_train = data.iloc[:,13]    # column 13: Asthma_10YR outcome

# Standardise the training set: z-score the five continuous predictors and
# carry the categorical predictors over unchanged. The fitted scaler is kept
# so the same transform can be applied to the test set below.
scaler = StandardScaler()
numeric_part = X_train.iloc[:, :5]
dummy_part = X_train.iloc[:, 5:].reset_index(drop=True)
cont_train = pd.DataFrame(
    scaler.fit_transform(numeric_part),
    columns=('Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4'),
)
SX_train = pd.concat([cont_train, dummy_part], axis=1)

# Import unstandardised preschool test data - data found in IOWBC_training_test_data.xlsx, sheet: "Preschool test set"
test = pd.read_csv("Preschool_test_dataset_183IDs.csv", index_col=False) 

# Split test data into features and outcome
X_test = test.drop(['Study_ID','Asthma_10YR'], axis=1)
y_test = test['Asthma_10YR']

# Standardise test dataset
# Uses transform (not fit_transform): the test set is scaled with the
# mean/variance learned from the training data above, avoiding leakage.
cont_test = pd.DataFrame(scaler.transform(X_test.iloc[:,0:5]), columns=('Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4'))
cat_test = X_test.iloc[:,5:]
SX_test = pd.concat([cont_test, cat_test.reset_index(drop=True)], axis=1)


### SVM - linear ###
# Define a svm classifier
# probability=True enables predict_proba for the ROC AUC computed later;
# random_state fixed so the probability calibration is reproducible.
clf = SVC(kernel='linear', probability=True, random_state=123)

##### Random search #####
# Candidate regularisation strengths: 100 log-spaced values over [1e-3, 1e2]
C_range = np.logspace(-3,2,100)
param_grid = dict(C=C_range)

# Run randomized search over C, scored by balanced accuracy with 5-fold
# stratified CV. NOTE: n_iter equals the number of candidates, so the whole
# grid is effectively tried.
random_search = RandomizedSearchCV(clf, scoring='balanced_accuracy',param_distributions=param_grid,
                                    n_iter=100, n_jobs=-1, cv=StratifiedKFold(5))
start = time()
random_search.fit(SX_train, y_train)
RStime = (time() - start)
# Print the elapsed time so it appears in non-interactive runs
# (the MICE section below already does this).
print(RStime)
# 219.09028911590576

best_parameters = random_search.best_params_
print(best_parameters)
# {'C': 0.3351602650938841}

best_score = random_search.best_score_
print(best_score)
# 0.6926756804838997

##### Grid search #####
# Fine grid of C values around the random-search optimum: 0.01-0.50, step 0.01
C_range = np.arange(0.01, 0.51, 0.01)
param_grid = dict(C=C_range)
# n_jobs=-1 (all available cores) for consistency with the random search
# above and to avoid hard-coding a core count; the search results themselves
# are unaffected, only the parallelism differs.
grid_search = GridSearchCV(clf, param_grid=param_grid, scoring='balanced_accuracy', n_jobs=-1, cv=StratifiedKFold(5))
start = time()
grid_search.fit(SX_train, y_train)
GStime = (time() - start)

# best parameters
best_parameters = grid_search.best_params_
print(best_parameters)
#'C': 0.32

best_score = grid_search.best_score_
print(best_score)
# 0.6926756804838997

# Fit optimised model
# C is hard-coded to the grid-search winner above (0.32) so the final model
# is reproducible even if the search cells are skipped on a re-run.
clf = SVC(C=0.32, kernel='linear', probability=True, random_state=123)
clf.fit(SX_train,y_train)

# Predicting the train set results
y_train_pred = clf.predict(SX_train)
cm_train = confusion_matrix(y_train, y_train_pred)
print(cm_train)
# [[302  12] [ 30  21]]

train_report = classification_report(y_train, y_train_pred)
print (train_report)
#              precision    recall  f1-score   support
#
#           0       0.91      0.96      0.93       314
#           1       0.64      0.41      0.50        51
#
#    accuracy                           0.88       365
#   macro avg       0.77      0.69      0.72       365
#weighted avg       0.87      0.88      0.87       365


# Printed (not bare expressions) so the values appear in non-interactive
# runs, consistent with the other metrics below.
print(accuracy_score(y_train, y_train_pred))
#0.8849315068493151

print(balanced_accuracy_score(y_train, y_train_pred))
#0.6867740726863993

# Sensitivity (recall of the positive class): TP / (TP + FN)
sensitivity =  cm_train[1,1]/(cm_train[1,0]+cm_train[1,1])
print(sensitivity)
#0.4117647058823529

# Specificity (recall of the negative class): TN / (TN + FP)
specificity = cm_train[0,0]/(cm_train[0,0]+cm_train[0,1])
print(specificity)
#0.9617834394904459

# Positive predictive value: TP / (TP + FP)
PPV = cm_train[1,1]/(cm_train[1,1]+cm_train[0,1])
print(PPV)
#0.6363636363636364

# Negative predictive value: TN / (TN + FN)
NPV = cm_train[0,0]/(cm_train[0,0]+cm_train[1,0])
print(NPV)
#0.9096385542168675

# Positive likelihood ratio: sensitivity / (1 - specificity)
LRp = sensitivity/(1-specificity)
print(LRp)
#10.774509803921571

# Negative likelihood ratio: (1 - sensitivity) / specificity
LRn = (1-sensitivity)/specificity
print(LRn)
#0.6116088819633814

#  AUC: 
# NOTE: roc_auc_score on hard 0/1 predictions equals balanced accuracy;
# the probability-based ROC AUC below is the discriminative measure.
AUC_train = roc_auc_score(y_train, y_train_pred)
print(AUC_train)
#0.6867740726863993

probs = clf.predict_proba(SX_train)
preds = probs[:,1]
ROCAUC_train = roc_auc_score(y_train, preds)
print(ROCAUC_train)
#0.8513175971025354

#Predict the response for test dataset
y_pred = clf.predict(SX_test)

cm_test = confusion_matrix(y_test, y_pred)
print (cm_test)
#[[138  20] [ 10  15]]

test_report = classification_report(y_test, y_pred)
print (test_report)
#              precision    recall  f1-score   support
#
#           0       0.93      0.87      0.90       158
#           1       0.43      0.60      0.50        25
#
#    accuracy                           0.84       183
#   macro avg       0.68      0.74      0.70       183
#weighted avg       0.86      0.84      0.85       183


# Printed (not bare expressions) so the values appear in non-interactive
# runs, consistent with the other metrics below.
print(accuracy_score(y_test, y_pred))
#0.8360655737704918

print(balanced_accuracy_score(y_test, y_pred))
#0.7367088607594936

# Sensitivity (recall of the positive class): TP / (TP + FN)
sensitivity =  cm_test[1,1]/(cm_test[1,0]+cm_test[1,1])
print(sensitivity)
#0.6

# Specificity (recall of the negative class): TN / (TN + FP)
specificity = cm_test[0,0]/(cm_test[0,0]+cm_test[0,1])
print(specificity)
#0.8734177215189873

# Positive predictive value: TP / (TP + FP)
PPV = cm_test[1,1]/(cm_test[1,1]+cm_test[0,1])
print(PPV)
#0.42857142857142855

# Negative predictive value: TN / (TN + FN)
NPV = cm_test[0,0]/(cm_test[0,0]+cm_test[1,0])
print(NPV)
#0.9324324324324325

# Positive likelihood ratio: sensitivity / (1 - specificity)
LRp = sensitivity/(1-specificity)
print(LRp)
#4.739999999999999

# Negative likelihood ratio: (1 - sensitivity) / specificity
LRn = (1-sensitivity)/specificity
print(LRn)
#0.4579710144927536

# NOTE: roc_auc_score on hard 0/1 predictions equals balanced accuracy;
# the probability-based ROC AUC below is the discriminative measure.
AUC_test = roc_auc_score(y_test, y_pred)
print(AUC_test)
#0.7367088607594937

probs = clf.predict_proba(SX_test)
preds = probs[:,1]
ROCAUC_test = roc_auc_score(y_test, preds)
print(ROCAUC_test)
#0.7941772151898734


############################################
###   MICE imputed model development     ###
############################################
# Import cleaned, unstandardised, MICE imputed preschool dataset # save dataset - data found in IOWBC_imputed_data.xlsx, sheet: "MICE complete data"
data = pd.read_csv("Imputed_MICE_preschool_complete_training_dataset_365IDs.csv", index_col=False)
# Drop the unnamed first column (likely a row-index column from the
# exporting tool) and strip the data-frame prefix from the outcome column
# so it matches the name used elsewhere in this script.
del data['Unnamed: 0']
data = data.rename(columns={'complete_data.Asthma_10YR': 'Asthma_10YR'})

# 365 Ids, 12 predictors

# Unlike the missForest file above, the predictors here start at column 0
# (presumably no Study_ID column) — TODO confirm column order in the CSV.
X_train = data.iloc[:,0:12]  # the 12 predictors
y_train = data.iloc[:,12]    # Asthma_10YR outcome

# Standardise the MICE training set.
# A fresh scaler is fitted to the five continuous predictors of this imputed
# dataset; the categorical predictors are concatenated back untouched.
scaler = StandardScaler()
continuous = X_train.iloc[:, 0:5]
scaled_frame = pd.DataFrame(
    scaler.fit_transform(continuous),
    columns=('Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4'),
)
SX_train = pd.concat(
    [scaled_frame, X_train.iloc[:, 5:].reset_index(drop=True)],
    axis=1,
)

# Import test data
# NOTE(review): absolute /scratch path here, unlike the relative
# "Preschool_test_dataset_183IDs.csv" used in the missForest section above —
# presumably the same 183-ID preschool test set; confirm the files match.
test = pd.read_csv("/scratch/dk2e18/Asthma_Prediction_Model/Oversampling/Test_data_preschool_dataset_183IDs.csv", index_col=False)

# Split test data into features and outcome
X_test = test.drop(['Study_ID','Asthma_10YR'], axis=1)
y_test = test['Asthma_10YR']

# Standardise test set with the scaler fitted on the MICE training data
# (transform only — no refitting, so no leakage from the test set)
cont_test = pd.DataFrame(scaler.transform(X_test.iloc[:,0:5]), columns=('Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4'))
cat_test = X_test.iloc[:,5:]
SX_test = pd.concat([cont_test, cat_test.reset_index(drop=True)], axis=1)


### SVM - linear ###
# Define a svm classifier
# Same configuration as the missForest section: linear kernel,
# probability=True for predict_proba, fixed random_state for reproducibility.
clf = SVC(kernel='linear', probability=True, random_state=123)

##### Random search #####
# Candidate regularisation strengths: 100 log-spaced values over [1e-3, 1e2]
C_range = np.logspace(-3,2,100)
param_grid = dict(C=C_range)

# Run randomized search over C, scored by balanced accuracy with 5-fold
# stratified CV. NOTE: n_iter equals the number of candidates, so the whole
# grid is effectively tried.
random_search = RandomizedSearchCV(clf, scoring='balanced_accuracy',param_distributions=param_grid,
                                    n_iter=100, n_jobs=-1, cv=StratifiedKFold(5))
start = time()
random_search.fit(SX_train, y_train)
RStime = (time() - start)
print(RStime)
# 3.326319456100464

best_parameters = random_search.best_params_
print(best_parameters)
# {'C': 0.13219411484660287}

best_score = random_search.best_score_
print(best_score)
# 0.7035936666073651

##### Grid search #####
# Fine grid of C values around the random-search optimum: 0.01-0.50, step 0.01
C_range = np.arange(0.01, 0.51, 0.01)
param_grid = dict(C=C_range)
# n_jobs=-1 (all available cores) for consistency with the random search
# above and to avoid hard-coding a core count; the search results themselves
# are unaffected, only the parallelism differs.
grid_search = GridSearchCV(clf, param_grid=param_grid, scoring='balanced_accuracy', n_jobs=-1, cv=StratifiedKFold(5))
start = time()
grid_search.fit(SX_train, y_train)
GStime = (time() - start)

# best parameters
best_parameters = grid_search.best_params_
print(best_parameters)
#'C': 0.13

best_score = grid_search.best_score_
print(best_score)
# 0.7035936666073651

# Fit optimised model
# C is hard-coded to the grid-search winner above (0.13) so the final model
# is reproducible even if the search cells are skipped on a re-run.
clf = SVC(C=0.13, kernel='linear', probability=True, random_state=123)
clf.fit(SX_train,y_train)

# Predicting the train set results
y_train_pred = clf.predict(SX_train)
cm_train = confusion_matrix(y_train, y_train_pred)
print(cm_train)
# [[305   9] [ 29  22]]

train_report = classification_report(y_train, y_train_pred)
print (train_report)
#              precision    recall  f1-score   support
#
#           0       0.91      0.97      0.94       314
#           1       0.71      0.43      0.54        51
#
#    accuracy                           0.90       365
#   macro avg       0.81      0.70      0.74       365
#weighted avg       0.88      0.90      0.88       365


# Printed (not bare expressions) so the values appear in non-interactive
# runs, consistent with the other metrics below.
print(accuracy_score(y_train, y_train_pred))
#0.8958904109589041

print(balanced_accuracy_score(y_train, y_train_pred))
#0.7013550643187212

# Sensitivity (recall of the positive class): TP / (TP + FN)
sensitivity =  cm_train[1,1]/(cm_train[1,0]+cm_train[1,1])
print(sensitivity)
#0.43137254901960786

# Specificity (recall of the negative class): TN / (TN + FP)
specificity = cm_train[0,0]/(cm_train[0,0]+cm_train[0,1])
print(specificity)
#0.9713375796178344

# Positive predictive value: TP / (TP + FP)
PPV = cm_train[1,1]/(cm_train[1,1]+cm_train[0,1])
print(PPV)
#0.7096774193548387

# Negative predictive value: TN / (TN + FN)
NPV = cm_train[0,0]/(cm_train[0,0]+cm_train[1,0])
print(NPV)
#0.9131736526946108

# Positive likelihood ratio: sensitivity / (1 - specificity)
LRp = sensitivity/(1-specificity)
print(LRp)
#15.050108932461878

# Negative likelihood ratio: (1 - sensitivity) / specificity
LRn = (1-sensitivity)/specificity
print(LRn)
#0.5854066216650594

#  AUC: 
# NOTE: roc_auc_score on hard 0/1 predictions equals balanced accuracy;
# the probability-based ROC AUC below is the discriminative measure.
AUC_train = roc_auc_score(y_train, y_train_pred)
print(AUC_train)
#0.7013550643187212

probs = clf.predict_proba(SX_train)
preds = probs[:,1]
ROCAUC_train = roc_auc_score(y_train, preds)
print(ROCAUC_train)
#0.8837267391032846

#Predict the response for test dataset
y_pred = clf.predict(SX_test)

cm_test = confusion_matrix(y_test, y_pred)
print (cm_test)
#[[145  13] [ 15  10]]

test_report = classification_report(y_test, y_pred)
print (test_report)
#              precision    recall  f1-score   support
#
#           0       0.91      0.92      0.91       158
#           1       0.43      0.40      0.42        25
#
#    accuracy                           0.85       183
#   macro avg       0.67      0.66      0.66       183
#weighted avg       0.84      0.85      0.84       183


# Printed (not bare expressions) so the values appear in non-interactive
# runs, consistent with the other metrics below.
print(accuracy_score(y_test, y_pred))
#0.8469945355191257

print(balanced_accuracy_score(y_test, y_pred))
#0.6588607594936708

# Sensitivity (recall of the positive class): TP / (TP + FN)
sensitivity =  cm_test[1,1]/(cm_test[1,0]+cm_test[1,1])
print(sensitivity)
#0.4

# Specificity (recall of the negative class): TN / (TN + FP)
specificity = cm_test[0,0]/(cm_test[0,0]+cm_test[0,1])
print(specificity)
#0.9177215189873418

# Positive predictive value: TP / (TP + FP)
PPV = cm_test[1,1]/(cm_test[1,1]+cm_test[0,1])
print(PPV)
#0.43478260869565216

# Negative predictive value: TN / (TN + FN)
NPV = cm_test[0,0]/(cm_test[0,0]+cm_test[1,0])
print(NPV)
#0.90625

# Positive likelihood ratio: sensitivity / (1 - specificity)
LRp = sensitivity/(1-specificity)
print(LRp)
#4.861538461538462

# Negative likelihood ratio: (1 - sensitivity) / specificity
LRn = (1-sensitivity)/specificity
print(LRn)
#0.6537931034482758

# NOTE: roc_auc_score on hard 0/1 predictions equals balanced accuracy;
# the probability-based ROC AUC below is the discriminative measure.
AUC_test = roc_auc_score(y_test, y_pred)
print(AUC_test)
#0.6588607594936708

probs = clf.predict_proba(SX_test)
preds = probs[:,1]
ROCAUC_test = roc_auc_score(y_test, preds)
print(ROCAUC_test)
#0.7739240506329115